import onnxruntime
from .base import Backend
import numpy as np
from geartrain.utils.import_utils import is_torch_available

if is_torch_available():
    import torch


class ORT_Backend(Backend):
    """Inference backend that runs an ONNX model through ONNX Runtime.

    Inputs may be numpy arrays or (when torch is installed) torch tensors.
    Outputs are returned as torch tensors on ``self.device`` when torch is
    available, otherwise as numpy arrays.
    """

    def __init__(self, model_path=None, device="cpu", **kwargs) -> None:
        """Store the model location and initialise the base backend.

        Args:
            model_path: Path to the ONNX model passed to
                ``onnxruntime.InferenceSession``.
            device: ``"cuda"`` selects the CUDA execution provider (with CPU
                fallback); any other value runs on CPU only.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        # These attributes are assigned before the base initialiser runs --
        # presumably because it triggers ``_load_model``; confirm in Backend.
        self.model_path = model_path
        self.fp16 = False
        super().__init__(device, "ort")

    def _load_model(self):
        """Create the ONNX Runtime session and cache the output names."""
        if self.device == "cuda":
            # CPU stays in the list so ORT has a provider to fall back on.
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        self.session = onnxruntime.InferenceSession(
            self.model_path, providers=providers
        )
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _forward(self, model_inputs):
        """Run the model on a single input and return its outputs.

        Args:
            model_inputs: A numpy array, a torch tensor, or anything
                ``np.asarray`` can convert. It is fed to the first input of
                the session.

        Returns:
            A single converted output when the model has exactly one output,
            otherwise a list with one converted entry per output
            (see ``from_numpy`` for the conversion).
        """
        if is_torch_available() and isinstance(model_inputs, torch.Tensor):
            # detach() so tensors that require grad can be exported to numpy.
            im = model_inputs.detach().cpu().numpy()
        else:
            im = np.asarray(model_inputs)

        # Cast to the precision the session is configured for. An input that
        # is already float16 must stay float16 when fp16 is enabled.
        target_dtype = np.float16 if self.fp16 else np.float32
        im = im.astype(target_dtype, copy=False)

        input_name = self.session.get_inputs()[0].name
        outputs = self.session.run(self.output_names, {input_name: im})

        # NOTE: outputs are returned in session order; no model-specific
        # reordering or transposition is applied here.
        converted = [
            self.from_numpy(out if isinstance(out, np.ndarray) else out.numpy())
            for out in outputs
        ]
        return converted[0] if len(converted) == 1 else converted

    def from_numpy(self, x):
        """
        Convert a numpy array to a tensor.

        Args:
            x (np.ndarray): The array to be converted.

        Returns:
            (torch.Tensor): The converted tensor on ``self.device`` when torch
            is available and ``x`` is a numpy array; otherwise ``x`` unchanged.
        """
        if is_torch_available() and isinstance(x, np.ndarray):
            return torch.tensor(x).to(self.device)
        return x
